{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import time\n",
    "from collections import defaultdict\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "import argparse\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"6\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "os.path.exists('.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "sys.path.append(\"../..\") # Adds higher directory to python modules path.\n",
    "from kge.model import KgeModel\n",
    "from kge.util.io import load_checkpoint\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading configuration of dataset fbwq_full_new...\n",
      "Setting complex.relation_embedder.dropout to 0, was set to -0.4746062345802784.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "ComplEx(\n",
       "  (_entity_embedder): LookupEmbedder(\n",
       "    (_embeddings): Embedding(1886683, 256, sparse=True)\n",
       "    (dropout): Dropout(p=0.44299429655075073, inplace=False)\n",
       "  )\n",
       "  (_relation_embedder): LookupEmbedder(\n",
       "    (_embeddings): Embedding(572, 256, sparse=True)\n",
       "    (dropout): Dropout(p=0, inplace=False)\n",
       "  )\n",
       "  (_scorer): ComplExScorer()\n",
       ")"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "checkpoint = load_checkpoint('../../pretrained_models/embeddings/ComplEx_fbwq_full_new/checkpoint_best.pt')\n",
    "model = KgeModel.create_from(checkpoint)\n",
    "model.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = open('../../data/fbwq_full_new/entity_ids.del')\n",
    "entity2idx = {}\n",
    "for line in f:\n",
    "    line = line.strip().split('\\t')\n",
    "    entity2idx[line[1]] = int(line[0])\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1255755 449\n"
     ]
    }
   ],
   "source": [
    "head = 'm.078ffw'\n",
    "rel = 'book.literary_series.works_in_this_series'\n",
    "rel_id = 0\n",
    "rels = torch.Tensor([i for i in range(548)]).long()\n",
    "for i, s in enumerate(model.dataset.relation_strings(rels)):\n",
    "    if s == rel:\n",
    "        rel_id = i\n",
    "s_id = entity2idx[head]\n",
    "print(s_id, rel_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[65.0816, 65.0736, 64.3656, 62.1204, 61.6874, 61.4585, 59.6895, 32.7663,\n",
      "         30.7057, 29.9945, 29.4358, 29.0522, 27.1528, 26.8394, 25.3394, 25.2249,\n",
      "         24.8479, 24.6993, 24.3568, 24.1571]], grad_fn=<TopkBackward>)\n",
      "tensor([[  20091,   20093,   31510,   20092,   31509,   20094,   31508,  615194,\n",
      "           78897,   28577,  955667,   78901,   33591,   59932, 1272405, 1473286,\n",
      "          246128, 1697097,   28580, 1254688]])\n",
      "['m.078ffw']\n",
      "['book.literary_series.works_in_this_series']\n",
      "[['m.06_rf9' 'm.03bkkv' 'm.015pln' 'm.01lr1g' 'm.0c_vk' 'm.014jst'\n",
      "  'm.01m5g_' 'm.026jlk9' 'm.03177r' 'm.031778' 'm.015dlb' 'm.03176f'\n",
      "  'm.03hxsv' 'm.031786' 'm.06hgrx3' 'm.06q7x14' 'm.07699t' 'm.0bqqz1h'\n",
      "  'm.031hcx' 'm.06f86m6']]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([20091])"
      ]
     },
     "execution_count": 172,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = torch.Tensor([s_id]).long()             # subject indexes\n",
    "p = torch.Tensor([rel_id]).long()             # relation indexes\n",
    "scores = model.score_sp(s, p)                # scores of all objects for (s,p,?)\n",
    "sc, o = torch.topk(scores, 20, largest=True, dim=-1)             # index of highest-scoring objects\n",
    "\n",
    "print(sc)\n",
    "print(o)\n",
    "print(model.dataset.entity_strings(s))       # convert indexes to mentions\n",
    "print(model.dataset.relation_strings(p))\n",
    "print(model.dataset.entity_strings(o))\n",
    "\n",
    "x = ComplEx(s_id,rel_id)\n",
    "y = torch.argmax(x, dim=-1)\n",
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ComplEx(head_id, relation_id):\n",
    "    head = model._entity_embedder._embeddings(torch.Tensor([head_id]).long())\n",
    "    relation = model._relation_embedder._embeddings(torch.Tensor([relation_id]).long())\n",
    "    head = torch.stack(list(torch.chunk(head, 2, dim=1)), dim=1)\n",
    "    head = head.permute(1, 0, 2)\n",
    "    re_head = head[0]\n",
    "    im_head = head[1]\n",
    "    re_relation, im_relation = torch.chunk(relation, 2, dim=1)\n",
    "    re_tail, im_tail = torch.chunk(model._entity_embedder._embeddings.weight, 2, dim =1)\n",
    "    re_score = re_head * re_relation - im_head * im_relation\n",
    "    im_score = re_head * im_relation + im_head * re_relation\n",
    "    score = torch.stack([re_score, im_score], dim=1)\n",
    "    score = score.permute(1, 0, 2)\n",
    "    re_score = score[0]\n",
    "    im_score = score[1]\n",
    "    score = torch.mm(re_score, re_tail.transpose(1,0)) + torch.mm(im_score, im_tail.transpose(1,0))\n",
    "#     pred = torch.sigmoid(score)\n",
    "    pred = score\n",
    "    return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.9431,  1.8672, -1.2449,  0.4421,  1.6214, -1.9546, -2.1960, -2.1960,\n",
       "         -0.2738, -0.7340, -1.1670, -0.9711, -1.8178, -0.7259,  0.1633,  1.0926,\n",
       "          0.5868, -0.2913,  0.8795, -0.2848, -1.2855,  0.7986,  1.8250, -1.2966,\n",
       "         -1.3353, -0.2483,  2.0538,  0.9398,  1.8204,  0.6461, -0.3098, -1.9697,\n",
       "          1.2385,  0.7508,  1.9437,  0.6374,  2.9127,  0.1030, -3.1933,  2.9447,\n",
       "         -1.2377,  0.5684, -0.8346,  0.7918,  0.1616,  1.3663, -1.3035,  1.1497,\n",
       "         -1.9876, -1.2987,  2.4317,  0.3529, -1.0604,  1.4364,  3.0076, -1.0881,\n",
       "         -2.1577, -1.4260,  2.1145, -2.5210,  3.3457,  1.2450,  1.8183, -1.9855,\n",
       "          1.1236,  0.5548,  2.3935, -1.2452, -1.8476,  0.0859,  1.4603,  1.0376,\n",
       "         -1.4839,  2.3597,  1.2356,  0.9587, -0.5229,  0.6607, -2.2229, -1.6865,\n",
       "          1.5329,  0.4195,  1.0505, -2.1253, -3.2505,  3.5873,  0.3544,  0.8479,\n",
       "         -3.2748,  0.3019,  2.3390, -0.9631, -1.9129, -2.9487,  0.3306, -2.2707,\n",
       "          1.8452, -1.8667, -1.9532, -1.2205,  0.4988, -0.9268,  0.9562, -1.2943,\n",
       "          1.2930, -0.2998,  2.0309,  1.2275,  0.6733,  2.7224,  0.5483, -2.0283,\n",
       "          1.3185, -1.8396, -0.1920, -0.0126, -1.0119,  1.9775, -2.3338,  1.0670,\n",
       "          0.0638,  0.6885,  0.0882,  0.1418,  0.8874, -1.2143, -2.4981, -3.7135,\n",
       "          1.7738, -0.8026,  2.4396, -0.1816,  1.4465, -2.6469,  0.5176,  0.3634,\n",
       "         -0.0662, -0.8931, -0.9980,  0.6977,  0.2659,  1.7742,  0.8992, -2.0696,\n",
       "          0.7195, -2.9481, -1.7409,  3.3326, -0.6698,  2.5715,  3.1598,  1.8159,\n",
       "          1.8004, -0.7888, -1.1824, -2.5905,  1.4273, -1.6804,  1.2804, -0.5095,\n",
       "         -0.4480, -1.6269, -1.2782, -2.6434,  2.5812,  1.1123,  1.2691, -1.7744,\n",
       "         -0.6213, -1.9244, -0.9869,  0.1272,  1.1737, -2.3012, -2.7948, -0.0312,\n",
       "         -1.5392,  0.8719,  2.2756,  0.3388,  2.5851, -0.8520, -3.4109,  0.0730,\n",
       "         -1.1319, -1.8582,  2.6414, -0.0698, -2.2329,  2.9172,  1.4663, -0.4598,\n",
       "          2.4856,  1.4408, -1.1757,  0.7153,  1.4690,  0.8060, -0.0180,  0.6347,\n",
       "         -2.9257,  1.2705,  1.2692, -0.7325,  1.7617, -0.4918,  1.4028, -2.3473,\n",
       "          1.4075,  0.6586,  0.8552, -0.7256, -0.6258, -0.4710, -0.3183, -0.4455,\n",
       "         -0.0108, -0.3946,  0.8728,  2.5880, -1.5673, -0.2497, -0.9129, -0.8179,\n",
       "          0.1696,  1.6470,  0.0243, -2.9847,  1.7645, -0.3590, -2.1679,  1.6370,\n",
       "          2.7882,  0.3268,  0.2041, -1.8796,  3.0934,  2.7651,  1.0688, -2.2211,\n",
       "          0.2559, -1.9114, -1.4154,  2.7549,  0.4749, -0.5551,  0.0063, -1.1731,\n",
       "          2.4016, -1.3737, -1.2573,  0.5014,  1.5153, -1.8998, -3.1910,  1.0104]],\n",
       "       grad_fn=<EmbeddingBackward>)"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = model._entity_embedder._embeddings(torch.Tensor([91245]).long())\n",
    "r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5567,  3.0712, -0.5995,  ..., -6.3965, -2.2072, -5.2597]],\n",
       "       grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  1.2696, -15.2891,  -9.8496,  ..., -13.2282,  -9.4417, -11.6062]],\n",
       "       grad_fn=<ViewBackward>)"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4890"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entity2idx['m.06w2sn5'] # justin beiber"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedder = model.get_s_embedder()\n",
    "e = 'm.0chgzm'\n",
    "e_id = entity2idx[e]\n",
    "input = torch.LongTensor([e_id])\n",
    "entity_embedding = embedder._embeddings(input)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1.0806e+00, -1.3234e+00,  1.1303e+00, -9.0067e-01,  4.3825e-01,\n",
       "        -3.7421e-01,  1.2426e+00, -1.0830e+00, -9.9529e-01, -1.4535e+00,\n",
       "        -1.3703e-01,  9.4352e-01, -1.0700e+00,  7.7651e-01,  3.6658e-01,\n",
       "         2.5738e+00,  7.8383e-02,  4.6301e-02, -5.3422e-01,  2.5738e-01,\n",
       "        -4.3520e-01, -1.9415e+00,  4.9657e-04, -2.8257e-01, -7.8855e-01,\n",
       "        -1.4186e+00, -1.0457e+00,  1.3098e+00, -5.1937e-01,  4.0517e-02,\n",
       "         6.6937e-01, -6.9332e-01,  7.6408e-01, -8.2103e-02, -7.2640e-01,\n",
       "         3.2012e-01, -7.5737e-01, -1.2666e+00, -6.5312e-01, -7.2035e-01,\n",
       "        -5.7599e-01, -9.1251e-02, -2.9292e-01,  1.3930e+00, -4.7956e-01,\n",
       "         2.2868e+00, -5.8547e-01, -1.3357e+00, -8.7579e-01,  1.3007e+00,\n",
       "        -5.1788e-01,  1.4566e-01,  7.2676e-01,  5.4709e-01, -3.4058e-01,\n",
       "        -1.5802e+00,  1.3112e+00,  2.1026e-01, -1.6019e+00,  9.4673e-03,\n",
       "         6.0343e-01,  1.6433e+00, -8.9710e-01, -2.6170e-01,  1.5939e+00,\n",
       "         8.2037e-01,  3.5385e-01, -5.6484e-02, -4.9849e-03,  9.1516e-01,\n",
       "         2.4103e-01, -2.7281e-01,  1.2645e+00, -1.2036e+00, -1.2291e+00,\n",
       "        -1.2569e-02,  1.5782e+00,  1.6298e+00, -2.1577e+00, -6.1310e-01,\n",
       "        -9.1265e-01,  9.9139e-01, -1.6632e-01, -1.5676e-01,  9.3259e-02,\n",
       "         2.2529e+00,  5.8701e-01, -1.1788e+00,  1.6903e+00, -1.1746e-01,\n",
       "        -1.3222e+00, -5.6887e-01, -1.0651e+00,  6.4412e-01, -5.6297e-02,\n",
       "         7.7951e-01, -5.9284e-01,  1.5836e+00,  8.7929e-01,  3.9707e-01,\n",
       "         1.4286e-01,  1.0481e+00,  1.3899e+00,  1.2178e+00, -6.1304e-01,\n",
       "         1.4523e+00,  6.8172e-01, -5.2006e-02,  2.5794e-01, -8.3557e-01,\n",
       "        -7.3996e-01, -1.8604e-01, -4.3966e-01, -1.5502e+00, -1.3511e+00,\n",
       "        -3.0204e-01,  5.3220e-03,  2.0587e+00, -1.3807e+00, -2.1650e-01,\n",
       "         8.9722e-01, -1.2008e+00, -1.4874e+00, -1.2969e+00,  7.0116e-01,\n",
       "        -1.9287e-01, -2.7587e+00, -1.6175e+00,  9.3958e-01, -1.1519e+00,\n",
       "        -2.0105e+00,  5.9960e-01, -1.9610e+00,  1.2502e+00, -2.2405e-02,\n",
       "        -1.9187e+00,  1.4772e+00, -6.0049e-01, -1.0363e-01, -2.3792e-01,\n",
       "         1.0000e+00, -9.5154e-01,  1.3823e-01,  9.3905e-02,  6.1060e-02,\n",
       "        -1.0837e+00, -1.4198e+00,  2.7847e-02, -9.9375e-01,  6.1164e-01,\n",
       "        -3.3739e-03, -1.1730e+00, -2.0931e+00, -2.0966e-01,  1.4153e+00,\n",
       "         2.2464e-01,  4.1015e-01, -1.3295e+00, -1.4179e+00,  1.8309e+00,\n",
       "         4.0333e-01,  1.7534e-01,  1.8215e+00, -9.5153e-01,  1.4157e+00,\n",
       "        -1.0255e+00, -2.5986e-01, -1.7508e-01, -1.0257e+00, -1.4768e+00,\n",
       "        -6.9719e-01, -1.9053e+00,  1.5560e+00,  8.5259e-03, -2.4123e+00,\n",
       "         2.5731e-01,  9.9310e-01,  3.1425e-01,  2.1522e+00, -6.1987e-01,\n",
       "        -1.9276e+00,  3.8515e-01,  1.6496e-01, -7.4041e-01,  1.8468e+00,\n",
       "         1.6013e+00,  6.4383e-01,  4.3381e-01,  9.7915e-01, -4.3659e-01,\n",
       "        -2.0663e-01, -2.3951e+00, -8.2991e-01,  1.3471e+00,  1.5835e+00,\n",
       "         4.7436e-02,  1.8824e+00,  6.2090e-01,  2.1865e-01,  1.6631e+00,\n",
       "        -5.3238e-02,  1.7428e+00,  1.5286e+00,  9.5163e-03,  2.0043e-01,\n",
       "         2.4221e-01,  7.7179e-01,  9.2815e-01, -1.0607e+00,  2.8212e-01,\n",
       "         1.6627e+00, -1.4881e+00, -1.9586e-01, -7.5592e-01, -5.0581e-01,\n",
       "        -2.0315e-02, -7.6245e-01, -5.4026e-01, -1.4282e+00, -1.2166e+00,\n",
       "         1.0838e+00, -9.1178e-01,  1.3860e+00, -8.7335e-01,  1.6007e+00,\n",
       "        -7.1875e-01,  7.2386e-02,  1.4861e+00,  1.9989e+00, -1.0547e+00,\n",
       "         8.6669e-01,  4.4126e-01, -1.4857e+00,  1.7543e+00, -8.4127e-01,\n",
       "         5.6203e-01,  1.1494e+00, -1.1422e+00, -1.0936e+00, -8.1720e-02,\n",
       "         4.2083e-01, -1.7725e+00, -3.2510e-02, -7.7534e-01, -3.5580e-02,\n",
       "        -8.7881e-01,  7.1996e-01,  1.6451e+00, -5.9253e-01,  5.6696e-01,\n",
       "        -1.0577e+00,  1.6957e+00,  1.2859e+00, -7.3025e-01,  2.4445e+00,\n",
       "        -1.2166e+00], grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entity_embedding[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1.0806e+00, -1.3234e+00,  1.1303e+00, -9.0067e-01,  4.3825e-01,\n",
       "        -3.7421e-01,  1.2426e+00, -1.0830e+00, -9.9529e-01, -1.4535e+00,\n",
       "        -1.3703e-01,  9.4352e-01, -1.0700e+00,  7.7651e-01,  3.6658e-01,\n",
       "         2.5738e+00,  7.8383e-02,  4.6301e-02, -5.3422e-01,  2.5738e-01,\n",
       "        -4.3520e-01, -1.9415e+00,  4.9657e-04, -2.8257e-01, -7.8855e-01,\n",
       "        -1.4186e+00, -1.0457e+00,  1.3098e+00, -5.1937e-01,  4.0517e-02,\n",
       "         6.6937e-01, -6.9332e-01,  7.6408e-01, -8.2103e-02, -7.2640e-01,\n",
       "         3.2012e-01, -7.5737e-01, -1.2666e+00, -6.5312e-01, -7.2035e-01,\n",
       "        -5.7599e-01, -9.1251e-02, -2.9292e-01,  1.3930e+00, -4.7956e-01,\n",
       "         2.2868e+00, -5.8547e-01, -1.3357e+00, -8.7579e-01,  1.3007e+00,\n",
       "        -5.1788e-01,  1.4566e-01,  7.2676e-01,  5.4709e-01, -3.4058e-01,\n",
       "        -1.5802e+00,  1.3112e+00,  2.1026e-01, -1.6019e+00,  9.4673e-03,\n",
       "         6.0343e-01,  1.6433e+00, -8.9710e-01, -2.6170e-01,  1.5939e+00,\n",
       "         8.2037e-01,  3.5385e-01, -5.6484e-02, -4.9849e-03,  9.1516e-01,\n",
       "         2.4103e-01, -2.7281e-01,  1.2645e+00, -1.2036e+00, -1.2291e+00,\n",
       "        -1.2569e-02,  1.5782e+00,  1.6298e+00, -2.1577e+00, -6.1310e-01,\n",
       "        -9.1265e-01,  9.9139e-01, -1.6632e-01, -1.5676e-01,  9.3259e-02,\n",
       "         2.2529e+00,  5.8701e-01, -1.1788e+00,  1.6903e+00, -1.1746e-01,\n",
       "        -1.3222e+00, -5.6887e-01, -1.0651e+00,  6.4412e-01, -5.6297e-02,\n",
       "         7.7951e-01, -5.9284e-01,  1.5836e+00,  8.7929e-01,  3.9707e-01,\n",
       "         1.4286e-01,  1.0481e+00,  1.3899e+00,  1.2178e+00, -6.1304e-01,\n",
       "         1.4523e+00,  6.8172e-01, -5.2006e-02,  2.5794e-01, -8.3557e-01,\n",
       "        -7.3996e-01, -1.8604e-01, -4.3966e-01, -1.5502e+00, -1.3511e+00,\n",
       "        -3.0204e-01,  5.3220e-03,  2.0587e+00, -1.3807e+00, -2.1650e-01,\n",
       "         8.9722e-01, -1.2008e+00, -1.4874e+00, -1.2969e+00,  7.0116e-01,\n",
       "        -1.9287e-01, -2.7587e+00, -1.6175e+00,  9.3958e-01, -1.1519e+00,\n",
       "        -2.0105e+00,  5.9960e-01, -1.9610e+00,  1.2502e+00, -2.2405e-02,\n",
       "        -1.9187e+00,  1.4772e+00, -6.0049e-01, -1.0363e-01, -2.3792e-01,\n",
       "         1.0000e+00, -9.5154e-01,  1.3823e-01,  9.3905e-02,  6.1060e-02,\n",
       "        -1.0837e+00, -1.4198e+00,  2.7847e-02, -9.9375e-01,  6.1164e-01,\n",
       "        -3.3739e-03, -1.1730e+00, -2.0931e+00, -2.0966e-01,  1.4153e+00,\n",
       "         2.2464e-01,  4.1015e-01, -1.3295e+00, -1.4179e+00,  1.8309e+00,\n",
       "         4.0333e-01,  1.7534e-01,  1.8215e+00, -9.5153e-01,  1.4157e+00,\n",
       "        -1.0255e+00, -2.5986e-01, -1.7508e-01, -1.0257e+00, -1.4768e+00,\n",
       "        -6.9719e-01, -1.9053e+00,  1.5560e+00,  8.5259e-03, -2.4123e+00,\n",
       "         2.5731e-01,  9.9310e-01,  3.1425e-01,  2.1522e+00, -6.1987e-01,\n",
       "        -1.9276e+00,  3.8515e-01,  1.6496e-01, -7.4041e-01,  1.8468e+00,\n",
       "         1.6013e+00,  6.4383e-01,  4.3381e-01,  9.7915e-01, -4.3659e-01,\n",
       "        -2.0663e-01, -2.3951e+00, -8.2991e-01,  1.3471e+00,  1.5835e+00,\n",
       "         4.7436e-02,  1.8824e+00,  6.2090e-01,  2.1865e-01,  1.6631e+00,\n",
       "        -5.3238e-02,  1.7428e+00,  1.5286e+00,  9.5163e-03,  2.0043e-01,\n",
       "         2.4221e-01,  7.7179e-01,  9.2815e-01, -1.0607e+00,  2.8212e-01,\n",
       "         1.6627e+00, -1.4881e+00, -1.9586e-01, -7.5592e-01, -5.0581e-01,\n",
       "        -2.0315e-02, -7.6245e-01, -5.4026e-01, -1.4282e+00, -1.2166e+00,\n",
       "         1.0838e+00, -9.1178e-01,  1.3860e+00, -8.7335e-01,  1.6007e+00,\n",
       "        -7.1875e-01,  7.2386e-02,  1.4861e+00,  1.9989e+00, -1.0547e+00,\n",
       "         8.6669e-01,  4.4126e-01, -1.4857e+00,  1.7543e+00, -8.4127e-01,\n",
       "         5.6203e-01,  1.1494e+00, -1.1422e+00, -1.0936e+00, -8.1720e-02,\n",
       "         4.2083e-01, -1.7725e+00, -3.2510e-02, -7.7534e-01, -3.5580e-02,\n",
       "        -8.7881e-01,  7.1996e-01,  1.6451e+00, -5.9253e-01,  5.6696e-01,\n",
       "        -1.0577e+00,  1.6957e+00,  1.2859e+00, -7.3025e-01,  2.4445e+00,\n",
       "        -1.2166e+00], grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = model._entity_embedder._embeddings.weight[0]\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.3 64-bit ('anaconda3': virtualenv)",
   "language": "python",
   "name": "python37364bitanaconda3virtualenvfe2cebba0c774becbe50d4d56f55e2fc"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
